--- Tests for 'Decimal' numbers
module tests.qc.Decimals where

import Test.QuickCheck as QC (quickCheck, verboseCheck, property, ==>, once, 
                Arbitrary, classify, label, Prop, Gen, forAll)
import Data.Dec64 hiding(>>)
import Data.Dec64 (¦, &)


--- Generator for random 'Decimal' values, assembled from two random bit patterns.
instance Arbitrary Decimal where
    -- The simpler @Decimal.fromBits <$> arbitrary@ is not used; the value is
    -- built from two independently generated parts instead.
    arbitrary = do
        -- random Long with its lowest 8 bits cleared by the mask
        upper ← fmap (0xFFffFFffFFffFF00L &) arbitrary
        -- NOTE(review): 'Int.shiftR 24' is a partial application, so the random
        -- Int becomes the second argument of shiftR (presumably the shift
        -- distance, with 24 being the shifted value) - confirm this is intended.
        lower ← fmap (Int.shiftR 24) arbitrary
        pure (Decimal.fromBits (upper ¦ lower.long))


-- p_dummy = once true

--- For non-negative arguments, 'divu10' agrees with @(`quot` 10)@; 'minBound' is skipped.
p_quotu10 = property agrees where
    agrees ∷ Long → Gen Prop
    agrees n = n != minBound ==> divu10 m == m `quot` 10
        where m = abs n

--- For non-negative arguments, 'remu10' agrees with @(`rem` 10)@; 'minBound' is skipped.
p_remu10 = property agrees where
    agrees ∷ Long → Gen Prop
    agrees n = n != minBound ==> remu10 m == m `rem` 10
        where m = abs n

--- 'mulu10' gives the same result as multiplying a 'Long' by @10L@.
p_mulu10 = property timesTen where
    timesTen ∷ Long → Bool
    timesTen k = k * 10L == mulu10 k
    
--- For @p = abs n@ (with 'minBound' skipped), @divu10 p * 10 + remu10 p@ gives back @p@.
p_divisionu10 = property recombines where
    recombines ∷ Long → Gen Prop
    recombines n = n != minBound ==> p == divu10 p * 10 + remu10 p
        where p = abs n

--- Swapping the operands of @+@ gives an equal result; expected to hold for 'Decimal.nan' operands too.
p_add_commutative = property swapped where
    swapped :: Decimal → Decimal → Bool
    swapped x y = x + y == y + x

{--
    Compares @(a + b) + c@ with @a + (b + c)@ and labels the outcome as
    equal, different by 1 ulp, or not associative.

    Every branch yields a passing result, so this property only collects
    statistics about the outcomes and never fails.
-}
p_add_associative = property assoc where
    assoc ∷ Decimal → Decimal → Decimal → Gen Prop
    assoc a b c
        | lhs == rhs          = label "equal" true
        | nextUp lhs == rhs   = label "different by 1 ulp" true
        | nextDown lhs == rhs = label "different by 1 ulp" true
        | otherwise           = label "not associative" true
        where
            lhs = (a + b) + c
            rhs = a + (b + c)

--- Adding @0z@ on the right leaves a 'Decimal' unchanged.
p_add_zero_neutral = property unchanged where
    unchanged ∷ Decimal → Bool
    unchanged d = d + 0z == d

{--
    @a + negate a == 0@ for every 'Decimal' that is not a NaN.

    Values whose coefficient equals 'Decimal.minCoefficient' are discarded,
    and NaN inputs are classified as "not a number" and accepted as is.
-}
p_negate = property cancels where
    cancels ∷ Decimal → Gen Prop
    cancels d = usable ==> classify d.isNaN "not a number" (d.isNaN || d + negate d == 0)
        where usable = d.coefficient != Decimal.minCoefficient

--- Swapping the operands of @*@ gives an equal result.
p_mul_commutative = property swapped where
    swapped :: Decimal → Decimal → Bool
    swapped x y = x * y == y * x

{--
    Compares @(a * b) * c@ with @a * (b * c)@ and labels the outcome as
    equal, different by 1 ulp, or not associative.

    It is not known whether associativity actually holds, so every branch
    yields a passing result and the property merely collects statistics.
-}
p_mul_associative = property assoc where
    assoc ∷ Decimal → Decimal → Decimal → Gen Prop
    assoc a b c
        | lhs == rhs          = label "equal" true
        | nextUp lhs == rhs   = label "different by 1 ulp" true
        | nextDown lhs == rhs = label "different by 1 ulp" true
        | otherwise           = label "not associative" true
        where
            lhs = (a * b) * c
            rhs = a * (b * c)

--- Multiplying any 'Decimal' by @0z@ on the right gives @0z@.
p_mul_zero = property vanishes where
    vanishes ∷ Decimal → Bool
    vanishes d = d * 0z == 0z

--- Multiplying by @1z@ on the left leaves a 'Decimal' unchanged.
p_mul_one_neutral = property unchanged where
    unchanged ∷ Decimal → Bool
    unchanged d = 1z * d == d
    
{--
    @abs x * abs a <= abs x * (abs a + 1)@

    Inputs are discarded when either operand is a NaN, when @x@ is zero, or
    when either coefficient equals 'Decimal.maxCoefficient'.
-}
p_mul_increasing = property grows where
    grows ∷ Decimal → Decimal → Gen Prop
    grows x a = admissible ==> (ax * aa) <= (ax * (aa + 1))
        where
            ax = abs x
            aa = abs a
            -- De Morgan form of "none of the excluded cases applies"
            admissible = not x.isNaN
                            && not a.isNaN
                            && not x.isZero
                            && x.coefficient != Decimal.maxCoefficient
                            && a.coefficient != Decimal.maxCoefficient

{--
    For a 'Decimal' that is neither zero nor a NaN, @reciprocal a * a@
    differs from @1z@ by at most @3e-16z@.

    Each test case is labelled with the observed error.
-}
p_reciprocal = property nearOne where
    nearOne ∷ Decimal → Gen Prop
    nearOne d = not (d.isZero || d.isNaN) ==> label ("error " ++ show delta) (delta <= 3e-16z)
        -- strict binding, as in the original formulation
        where !delta = abs (1z - (reciprocal d * d))

{--
    Always passes; it only records how the generated values are distributed
    over the classes "nan", "zero" and "even" (even coefficient).
-}
p_nonspecial = property census where
    census ∷ Decimal → Gen Prop
    census d = classify d.isNaN "nan"
                $ classify d.isZero "zero"
                $ classify d.coefficient.even "even" true

{--
    For every @n@ from @1L@ up to @z@, produce the triple @(n, r, e)@ where
    @r@ is the 'reciprocal' of @n@ converted with 'fromIntegral' and
    @e@ is @abs (1z - r*d)@, the distance of @r*d@ from one.
-}
rUpto z = map triple [1L..z]
    where
        triple n = (n, r, e)
            where
                d = fromIntegral n
                r = reciprocal d
                e = abs (1z - r*d)

{--
    Folds over @rUpto Decimal.maxCoefficient@ and prints a pair consisting of
    the largest error @e@ encountered and the number of entries whose error
    is not zero.
-}
main = println (fold f (0z,0L) (rUpto Decimal.maxCoefficient))
    where
        -- Accumulator: (largest error so far, count of non-zero errors).
        -- The reciprocal r of each triple is not used by the fold step.
        f (maxE, wrong) (n,r,e)
            -- Progress output: traceLn is evaluated for every n divisible by 100000.
            -- NOTE(review): presumably traceLn yields false, so 'undefined' is never
            -- reached and the next guard applies - confirm.
            | n `rem` 100000 == 0, traceLn(show n) = undefined
            | otherwise = case max maxE e of
            -- bang patterns: both components are evaluated before the new pair is built
            !maxE' -> case if e.isZero then wrong else wrong +1 of
                !wrong' -> (maxE', wrong')  
